yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r857052087


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input 
evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may 
contains
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be 
defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated 
data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                
BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), 
getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> 
evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, 
getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> 
chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, 
Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> 
out) {
+                                List<Tuple3<Double, Boolean, Double>> 
bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> 
t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, 
t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o 
-> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint 
: bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, 
weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, 
partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, 
Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, 
Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, 
Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * 
positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })

Review Comment:
   Shall we move anonymous functions like this one into private static classes 
and add Javadoc to them? That would help improve readability.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to