This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 5619c3b8 [FLINK-32889] Fix calculation of weighted areaUnderROC and 
areaUnderPRC in BinaryClassificationEvaluator
5619c3b8 is described below

commit 5619c3b8591b220e78a0a792c1f940e06149c8f0
Author: Fan Hong <[email protected]>
AuthorDate: Thu Aug 24 14:40:12 2023 +0800

    [FLINK-32889] Fix calculation of weighted areaUnderROC and areaUnderPRC in 
BinaryClassificationEvaluator
    
    This closes #252.
---
 .../BinaryClassificationEvaluator.java             | 298 +++++----------------
 .../BinaryClassificationEvaluatorTest.java         |  17 +-
 .../evaluation/tests/tests_binaryclassification.py |  10 +-
 3 files changed, 86 insertions(+), 239 deletions(-)

diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
index d74e40b2..051d7513 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
@@ -20,8 +20,6 @@ package org.apache.flink.ml.evaluation.binaryclassification;
 
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
-import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
 import org.apache.flink.api.common.state.ListState;
@@ -64,7 +62,6 @@ import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Iterator;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
@@ -124,7 +121,7 @@ public class BinaryClassificationEvaluator
                                     Iterable<Tuple4<Double, Boolean, Double, 
Integer>> values,
                                     Collector<Tuple3<Double, Boolean, Double>> 
out) {
                                 List<Tuple3<Double, Boolean, Double>> 
bufferedData =
-                                        new LinkedList<>();
+                                        new ArrayList<>();
                                 for (Tuple4<Double, Boolean, Double, Integer> 
t4 : values) {
                                     bufferedData.add(Tuple3.of(t4.f0, t4.f1, 
t4.f2));
                                 }
@@ -142,48 +139,8 @@ public class BinaryClassificationEvaluator
                         TypeInformation.of(BinarySummary.class),
                         new PartitionSummaryOperator());
 
-        /* Sorts global data. Output Tuple4 : <score, order, isPositive, 
weight>. */
-        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
-                BroadcastUtils.withBroadcastStream(
-                        Collections.singletonList(sortEvalData),
-                        Collections.singletonMap(partitionSummariesKey, 
partitionSummaries),
-                        inputList -> {
-                            DataStream input = inputList.get(0);
-                            return input.flatMap(new 
CalcSampleOrders(partitionSummariesKey));
-                        });
-
-        DataStream<double[]> localAreaUnderROCVariable =
-                dataWithOrders.transform(
-                        "AccumulateMultiScore",
-                        TypeInformation.of(double[].class),
-                        new AccumulateMultiScoreOperator());
-
-        DataStream<double[]> middleAreaUnderROC =
-                DataStreamUtils.reduce(
-                        localAreaUnderROCVariable,
-                        (ReduceFunction<double[]>)
-                                (t1, t2) -> {
-                                    t2[0] += t1[0];
-                                    t2[1] += t1[1];
-                                    t2[2] += t1[2];
-                                    return t2;
-                                });
-
-        DataStream<Double> areaUnderROC =
-                middleAreaUnderROC.map(
-                        (MapFunction<double[], Double>)
-                                value -> {
-                                    if (value[1] > 0 && value[2] > 0) {
-                                        return (value[0] - 1. * value[1] * 
(value[1] + 1) / 2)
-                                                / (value[1] * value[2]);
-                                    } else {
-                                        return Double.NaN;
-                                    }
-                                });
-
         Map<String, DataStream<?>> broadcastMap = new HashMap<>();
         broadcastMap.put(partitionSummariesKey, partitionSummaries);
-        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
         DataStream<BinaryMetrics> localMetrics =
                 BroadcastUtils.withBroadcastStream(
                         Collections.singletonList(sortEvalData),
@@ -218,89 +175,6 @@ public class BinaryClassificationEvaluator
         return new Table[] {tEnv.fromDataStream(evalResult)};
     }
 
-    /** Updates variables for calculating AreaUnderROC. */
-    private static class AccumulateMultiScoreOperator extends 
AbstractStreamOperator<double[]>
-            implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, 
Double>, double[]>,
-                    BoundedOneInput {
-        private ListState<double[]> accValueState;
-        private ListState<Double> scoreState;
-
-        double[] accValue;
-        double score;
-
-        @Override
-        public void endInput() {
-            if (accValue != null) {
-                output.collect(
-                        new StreamRecord<>(
-                                new double[] {
-                                    accValue[0] / accValue[1] * accValue[2],
-                                    accValue[2],
-                                    accValue[3]
-                                }));
-            }
-        }
-
-        @Override
-        public void processElement(
-                StreamRecord<Tuple4<Double, Long, Boolean, Double>> 
streamRecord) {
-            Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue();
-            if (accValue == null) {
-                accValue = new double[4];
-                score = t.f0;
-            } else if (score != t.f0) {
-                output.collect(
-                        new StreamRecord<>(
-                                new double[] {
-                                    accValue[0] / accValue[1] * accValue[2],
-                                    accValue[2],
-                                    accValue[3]
-                                }));
-                Arrays.fill(accValue, 0.0);
-            }
-            accValue[0] += t.f1;
-            accValue[1] += 1.0;
-            if (t.f2) {
-                accValue[2] += t.f3;
-            } else {
-                accValue[3] += t.f3;
-            }
-        }
-
-        @Override
-        @SuppressWarnings("unchecked")
-        public void initializeState(StateInitializationContext context) throws 
Exception {
-            super.initializeState(context);
-            accValueState =
-                    context.getOperatorStateStore()
-                            .getListState(
-                                    new ListStateDescriptor<>(
-                                            "accValueState", 
TypeInformation.of(double[].class)));
-            accValue =
-                    OperatorStateUtils.getUniqueElement(accValueState, 
"accValueState")
-                            .orElse(null);
-
-            scoreState =
-                    context.getOperatorStateStore()
-                            .getListState(
-                                    new ListStateDescriptor<>(
-                                            "scoreState", 
TypeInformation.of(Double.class)));
-            score = OperatorStateUtils.getUniqueElement(scoreState, 
"scoreState").orElse(0.0);
-        }
-
-        @Override
-        @SuppressWarnings("unchecked")
-        public void snapshotState(StateSnapshotContext context) throws 
Exception {
-            super.snapshotState(context);
-            accValueState.clear();
-            scoreState.clear();
-            if (accValue != null) {
-                accValueState.add(accValue);
-                scoreState.add(score);
-            }
-        }
-    }
-
     private static class PartitionSummaryOperator extends 
AbstractStreamOperator<BinarySummary>
             implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, 
BinarySummary>,
                     BoundedOneInput {
@@ -320,7 +194,6 @@ public class BinaryClassificationEvaluator
         }
 
         @Override
-        @SuppressWarnings("unchecked")
         public void initializeState(StateInitializationContext context) throws 
Exception {
             super.initializeState(context);
             summaryState =
@@ -340,7 +213,6 @@ public class BinaryClassificationEvaluator
         }
 
         @Override
-        @SuppressWarnings("unchecked")
         public void snapshotState(StateSnapshotContext context) throws 
Exception {
             super.snapshotState(context);
             summaryState.clear();
@@ -385,24 +257,24 @@ public class BinaryClassificationEvaluator
 
             List<BinarySummary> statistics =
                     
getRuntimeContext().getBroadcastVariable(partitionSummariesKey);
-            long[] countValues =
+            double[] accWeights =
                     reduceBinarySummary(statistics, 
getRuntimeContext().getIndexOfThisSubtask());
 
-            double areaUnderROC =
-                    
getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0);
-            long totalTrue = countValues[2];
-            long totalFalse = countValues[3];
-            if (totalTrue == 0) {
-                LOG.warn("There is no positive sample in data!");
+            double totalSumWeightsPos = accWeights[2];
+            double totalSumWeightsNeg = accWeights[3];
+            if (totalSumWeightsPos == 0) {
+                LOG.warn("There is no positive samples in data!");
             }
-            if (totalFalse == 0) {
-                LOG.warn("There is no negative sample in data!");
+            if (totalSumWeightsNeg == 0) {
+                LOG.warn("There is no negative samples in data!");
             }
 
-            BinaryMetrics metrics = new BinaryMetrics(0L, areaUnderROC);
+            BinaryMetrics metrics = new BinaryMetrics(0);
+            // Stores values of TPR, FPR, Precision, and PR calculated from 
samples with scores
+            // ranging from the maximum to the current one.
             double[] tprFprPrecision = new double[4];
             for (Tuple3<Double, Boolean, Double> t3 : iterable) {
-                updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision);
+                updateBinaryMetrics(t3, metrics, accWeights, tprFprPrecision);
             }
             collector.collect(metrics);
         }
@@ -411,35 +283,36 @@ public class BinaryClassificationEvaluator
     private static void updateBinaryMetrics(
             Tuple3<Double, Boolean, Double> cur,
             BinaryMetrics binaryMetrics,
-            long[] countValues,
+            double[] accWeights,
             double[] recordValues) {
-        if (binaryMetrics.count == 0) {
-            recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] 
/ countValues[2];
-            recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] 
/ countValues[3];
+        if (binaryMetrics.sumWeights == 0) {
+            recordValues[0] = accWeights[2] == 0 ? 1.0 : accWeights[0] / 
accWeights[2];
+            recordValues[1] = accWeights[3] == 0 ? 1.0 : accWeights[1] / 
accWeights[3];
             recordValues[2] =
-                    countValues[0] + countValues[1] == 0
+                    accWeights[0] + accWeights[1] == 0
                             ? 1.0
-                            : 1.0 * countValues[0] / (countValues[0] + 
countValues[1]);
-            recordValues[3] =
-                    1.0 * (countValues[0] + countValues[1]) / (countValues[2] 
+ countValues[3]);
+                            : accWeights[0] / (accWeights[0] + accWeights[1]);
+            recordValues[3] = (accWeights[0] + accWeights[1]) / (accWeights[2] 
+ accWeights[3]);
         }
 
-        binaryMetrics.count++;
-        if (cur.f1) {
-            countValues[0]++;
+        boolean isPos = cur.f1;
+        double weight = cur.f2;
+        binaryMetrics.sumWeights += weight;
+        if (isPos) {
+            accWeights[0] += weight;
         } else {
-            countValues[1]++;
+            accWeights[1] += weight;
         }
 
-        double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / 
countValues[2];
-        double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / 
countValues[3];
+        double tpr = accWeights[2] == 0 ? 1.0 : accWeights[0] / accWeights[2];
+        double fpr = accWeights[3] == 0 ? 1.0 : accWeights[1] / accWeights[3];
         double precision =
-                countValues[0] + countValues[1] == 0
+                accWeights[0] + accWeights[1] == 0
                         ? 1.0
-                        : 1.0 * countValues[0] / (countValues[0] + 
countValues[1]);
-        double positiveRate =
-                1.0 * (countValues[0] + countValues[1]) / (countValues[2] + 
countValues[3]);
+                        : accWeights[0] / (accWeights[0] + accWeights[1]);
+        double positiveRate = (accWeights[0] + accWeights[1]) / (accWeights[2] 
+ accWeights[3]);
 
+        binaryMetrics.areaUnderROC += (fpr - recordValues[1]) * (tpr + 
recordValues[0]) / 2;
         binaryMetrics.areaUnderLorenz +=
                 ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 
2);
         binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + 
recordValues[2]) / 2);
@@ -451,65 +324,31 @@ public class BinaryClassificationEvaluator
         recordValues[3] = positiveRate;
     }
 
-    /**
-     * For each sample, calculates its score order among all samples. The 
sample with minimum score
-     * has order 1, while the sample with maximum score has order samples.
-     *
-     * <p>Input is a dataset of tuple (score, is real positive, weight), 
output is a dataset of
-     * tuple (score, order, is real positive, weight).
-     */
-    private static class CalcSampleOrders
-            extends RichFlatMapFunction<
-                    Tuple3<Double, Boolean, Double>, Tuple4<Double, Long, 
Boolean, Double>> {
-        private long startIndex;
-        private long total = -1;
-        private final String partitionSummariesKey;
-
-        public CalcSampleOrders(String partitionSummariesKey) {
-            this.partitionSummariesKey = partitionSummariesKey;
-        }
-
-        @Override
-        public void flatMap(
-                Tuple3<Double, Boolean, Double> value,
-                Collector<Tuple4<Double, Long, Boolean, Double>> out)
-                throws Exception {
-            if (total == -1) {
-                List<BinarySummary> statistics =
-                        
getRuntimeContext().getBroadcastVariable(partitionSummariesKey);
-                long[] countValues =
-                        reduceBinarySummary(
-                                statistics, 
getRuntimeContext().getIndexOfThisSubtask());
-                startIndex = countValues[1] + countValues[0] + 1;
-                total = countValues[2] + countValues[3];
-            }
-            out.collect(Tuple4.of(value.f0, total - startIndex + 1, value.f1, 
value.f2));
-            startIndex++;
-        }
-    }
-
     /**
      * @param values Reduce Summary of all workers.
      * @param taskId current taskId.
-     * @return [curTrue, curFalse, TotalTrue, TotalFalse]
+     * @return An array storing sum of weights of positives/negatives of tasks 
before the current
+     *     one, and sum of weights of positives/negatives of all tasks.
      */
-    private static long[] reduceBinarySummary(List<BinarySummary> values, int 
taskId) {
+    private static double[] reduceBinarySummary(List<BinarySummary> values, 
int taskId) {
         List<BinarySummary> list = new ArrayList<>(values);
         list.sort(Comparator.comparingDouble(t -> -t.maxScore));
-        long curTrue = 0;
-        long curFalse = 0;
-        long totalTrue = 0;
-        long totalFalse = 0;
+        double prefixSumWeightsPos = 0;
+        double prefixSumWeightsNeg = 0;
+        double totalSumWeightsPos = 0;
+        double totalSumWeightsNeg = 0;
 
         for (BinarySummary statistics : list) {
             if (statistics.taskId == taskId) {
-                curFalse = totalFalse;
-                curTrue = totalTrue;
+                prefixSumWeightsNeg = totalSumWeightsNeg;
+                prefixSumWeightsPos = totalSumWeightsPos;
             }
-            totalTrue += statistics.curPositive;
-            totalFalse += statistics.curNegative;
+            totalSumWeightsPos += statistics.sumWeightsPos;
+            totalSumWeightsNeg += statistics.sumWeightsNeg;
         }
-        return new long[] {curTrue, curFalse, totalTrue, totalFalse};
+        return new double[] {
+            prefixSumWeightsPos, prefixSumWeightsNeg, totalSumWeightsPos, 
totalSumWeightsNeg
+        };
     }
 
     /**
@@ -520,13 +359,16 @@ public class BinaryClassificationEvaluator
      */
     private static void updateBinarySummary(
             BinarySummary statistics, Tuple3<Double, Boolean, Double> 
evalElement) {
-        if (evalElement.f1) {
-            statistics.curPositive++;
+        boolean isPos = evalElement.f1;
+        double weight = evalElement.f2;
+        double score = evalElement.f0;
+        if (isPos) {
+            statistics.sumWeightsPos += weight;
         } else {
-            statistics.curNegative++;
+            statistics.sumWeightsNeg += weight;
         }
-        if (Double.compare(statistics.maxScore, evalElement.f0) < 0) {
-            statistics.maxScore = evalElement.f0;
+        if (Double.compare(statistics.maxScore, score) < 0) {
+            statistics.maxScore = score;
         }
     }
 
@@ -673,25 +515,26 @@ public class BinaryClassificationEvaluator
         public Integer taskId;
         // maximum score in this partition
         public double maxScore;
-        // real positives in this partition
-        public long curPositive;
-        // real negatives in this partition
-        public long curNegative;
+        // sum of weights of positives in this partition
+        public double sumWeightsPos;
+        // sum of weights of negatives in this partition
+        public double sumWeightsNeg;
 
         public BinarySummary() {}
 
-        public BinarySummary(Integer taskId, double maxScore, long 
curPositive, long curNegative) {
+        public BinarySummary(
+                Integer taskId, double maxScore, double sumWeightsPos, double 
sumWeightsNeg) {
             this.taskId = taskId;
             this.maxScore = maxScore;
-            this.curPositive = curPositive;
-            this.curNegative = curNegative;
+            this.sumWeightsPos = sumWeightsPos;
+            this.sumWeightsNeg = sumWeightsNeg;
         }
     }
 
     /** The evaluation metrics for binary classification. */
     public static class BinaryMetrics {
-        /* The count of samples. */
-        public long count;
+        /* The sum of weights of samples. */
+        public double sumWeights;
 
         /* Area under ROC */
         public double areaUnderROC;
@@ -707,22 +550,19 @@ public class BinaryClassificationEvaluator
 
         public BinaryMetrics() {}
 
-        public BinaryMetrics(long count, double areaUnderROC) {
-            this.count = count;
-            this.areaUnderROC = areaUnderROC;
+        public BinaryMetrics(long sumWeights) {
+            this.sumWeights = sumWeights;
         }
 
         public BinaryMetrics merge(BinaryMetrics binaryClassMetrics) {
             if (null == binaryClassMetrics) {
                 return this;
             }
-            Preconditions.checkState(
-                    Double.compare(areaUnderROC, 
binaryClassMetrics.areaUnderROC) == 0,
-                    "AreaUnderROC not equal!");
-            count += binaryClassMetrics.count;
-            ks = Math.max(ks, binaryClassMetrics.ks);
-            areaUnderPR += binaryClassMetrics.areaUnderPR;
+            sumWeights += binaryClassMetrics.sumWeights;
+            areaUnderROC += binaryClassMetrics.areaUnderROC;
             areaUnderLorenz += binaryClassMetrics.areaUnderLorenz;
+            areaUnderPR += binaryClassMetrics.areaUnderPR;
+            ks = Math.max(ks, binaryClassMetrics.ks);
             return this;
         }
     }
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
index 0c146a3d..f6f38b29 100644
--- 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
@@ -118,7 +118,8 @@ public class BinaryClassificationEvaluatorTest extends 
AbstractTestBase {
             new double[] {
                 0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 
0.6488095238095237
             };
-    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double[] EXPECTED_DATA_W =
+            new double[] {0.8717948717948718, 0.9510202726261435};
     private static final double EPS = 1.0e-5;
 
     @Before
@@ -297,14 +298,20 @@ public class BinaryClassificationEvaluatorTest extends 
AbstractTestBase {
     public void testEvaluateWithWeight() {
         BinaryClassificationEvaluator eval =
                 new BinaryClassificationEvaluator()
-                        
.setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)
+                        .setMetricsNames(
+                                
BinaryClassificationEvaluatorParams.AREA_UNDER_ROC,
+                                
BinaryClassificationEvaluatorParams.AREA_UNDER_PR)
                         .setWeightCol("weight");
         Table evalResult = eval.transform(inputDataTableWithWeight)[0];
-        List<Row> results = 
IteratorUtils.toList(evalResult.execute().collect());
+        Row result = (Row) 
IteratorUtils.toList(evalResult.execute().collect()).get(0);
         assertArrayEquals(
-                new String[] 
{BinaryClassificationEvaluatorParams.AREA_UNDER_ROC},
+                new String[] {
+                    BinaryClassificationEvaluatorParams.AREA_UNDER_ROC,
+                    BinaryClassificationEvaluatorParams.AREA_UNDER_PR
+                },
                 evalResult.getResolvedSchema().getColumnNames().toArray());
-        assertEquals(EXPECTED_DATA_W, results.get(0).getFieldAs(0), EPS);
+        assertArrayEquals(
+                EXPECTED_DATA_W, new double[] {result.getFieldAs(0), 
result.getFieldAs(1)}, EPS);
     }
 
     @Test
diff --git 
a/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py 
b/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
index 7f34eb4c..55a8f2b9 100644
--- a/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
+++ b/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
@@ -16,10 +16,10 @@
 # limitations under the License.
 
################################################################################
 import os
-
 from pyflink.common import Types
-from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
+
 from pyflink.ml.evaluation.binaryclassification import 
BinaryClassificationEvaluator
+from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
 from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
 
 
@@ -111,7 +111,7 @@ class BinaryClassificationEvaluatorTest(PyFlinkMLTestCase):
         self.expected_data_m = [0.8571428571428571, 0.9377705627705628,
                                 0.8571428571428571, 0.6488095238095237]
 
-        self.expected_data_w = 0.8911680911680911
+        self.expected_data_w = [0.8717948717948718, 0.9510202726261435]
 
         self.eps = 1e-5
 
@@ -185,11 +185,11 @@ class 
BinaryClassificationEvaluatorTest(PyFlinkMLTestCase):
 
     def test_evaluate_with_weight(self):
         evaluator = BinaryClassificationEvaluator() \
-            .set_metrics_names("areaUnderROC") \
+            .set_metrics_names("areaUnderROC", "areaUnderPR") \
             .set_weight_col("weight")
         output = evaluator.transform(self.input_data_table_with_weight)[0]
         self.assertEqual(
-            ["areaUnderROC"],
+            ["areaUnderROC", "areaUnderPR"],
             output.get_schema().get_field_names())
         results = [result for result in output.execute().collect()]
         result = results[0]

Reply via email to