lindong28 commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r859293793


##########
flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java:
##########
@@ -463,5 +463,10 @@ public void testValidators() {
         Assert.assertTrue(nonEmptyArray.validate(new String[] {"1"}));
         Assert.assertFalse(nonEmptyArray.validate(null));
         Assert.assertFalse(nonEmptyArray.validate(new String[0]));
+
+        ParamValidator<String[]> isSubArray = ParamValidators.isSubArray("a", "b", "c");
+        Assert.assertFalse(isSubArray.validate(new String[] {"c", "v"}));

Review Comment:
   nit: could we also check the case where the input value is null, to be consistent with the tests above?
   
   `Assert.assertFalse(isSubArray.validate(null))`
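
   For context, a minimal sketch of a validator that would satisfy this extra assertion. `SubArrayValidatorSketch` and its use of `java.util.function.Predicate` are hypothetical stand-ins, not the `ParamValidators.isSubArray` implementation in this PR; the sketch only assumes that a null input should be rejected.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Predicate;

/** Hypothetical stand-in for the validator under review; not the Flink ML implementation. */
class SubArrayValidatorSketch {

    /** Returns a predicate that accepts only non-null arrays whose elements all appear in allowed. */
    static Predicate<String[]> isSubArray(String... allowed) {
        Set<String> allowedSet = new HashSet<>(Arrays.asList(allowed));
        // The null check comes first so that a null input is rejected instead of throwing.
        return value -> value != null && allowedSet.containsAll(Arrays.asList(value));
    }

    public static void main(String[] args) {
        Predicate<String[]> validator = isSubArray("a", "b", "c");
        System.out.println(validator.test(new String[] {"a", "c"}));
        System.out.println(validator.test(new String[] {"c", "v"}));
        System.out.println(validator.test(null));
    }
}
```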



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and
+     * 'areaUnderLorenz').
+     *
+     * <p>areaUnderROC: the area under the receiver operating characteristic (ROC) curve.

Review Comment:
   nit: could we update the Javadoc here to follow the format used in `HasHandleInvalid.java`? That format seems more readable.
   
   Currently the docs for `KS` and `areaUnderLorenz` are on the same line.
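
   A rough sketch of what this could look like, assuming `HasHandleInvalid.java` lists each option as its own `<li>` entry inside a `<ul>`. `MetricsDocSketch` and `SUPPORTED_METRICS` are hypothetical names, and the descriptions for `areaUnderPR`, `KS` and `areaUnderLorenz` are placeholder wording rather than the text in this PR.

```java
/** Hypothetical params interface, used only to show the Javadoc layout. */
interface MetricsDocSketch {
    /**
     * Param for the metric names used in evaluation.
     *
     * <p>Supported options are listed as follows.
     *
     * <ul>
     *   <li>areaUnderROC: the area under the receiver operating characteristic (ROC) curve.
     *   <li>areaUnderPR: the area under the precision-recall curve.
     *   <li>KS: the Kolmogorov-Smirnov statistic.
     *   <li>areaUnderLorenz: the area under the Lorenz curve.
     * </ul>
     */
    String[] SUPPORTED_METRICS = {"areaUnderROC", "areaUnderPR", "KS", "areaUnderLorenz"};
}
```

   With one item per line, each metric gets its own entry in the rendered Javadoc.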



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =

Review Comment:
   Hmm... since `BinaryClassificationEvaluator` does not do any training (i.e. `fit`), would it be more intuitive to rename this to `INPUT_DATA*`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,731 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The output may contain different metrics
+ * which will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator

Review Comment:
   According to Spark's Javadoc, its `BinaryClassificationEvaluator` supports a rawPrediction column of type double (binary 0/1 prediction, or probability of label 1), in addition to a rawPrediction column of type vector (length-2 vector of raw predictions, scores, or label probabilities).
   
   Should we also support a rawPrediction column of type double? If not, how do we handle the use cases that currently require a rawPrediction column of type double in Spark?
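
   One possible way to accept both input types, as a rough sketch only. `RawPredictionSketch` and `scoreForPositiveLabel` are hypothetical names, a plain `double[]` stands in for the vector type, and the sketch assumes that index 1 of the length-2 vector holds the score for label 1.

```java
/** Hypothetical helper; double[] stands in for a length-2 vector of raw predictions. */
class RawPredictionSketch {

    /** Extracts the score for label 1 from either a scalar or a length-2 array. */
    static double scoreForPositiveLabel(Object rawPrediction) {
        if (rawPrediction instanceof Number) {
            // Scalar case: a binary 0/1 prediction or the probability of label 1.
            return ((Number) rawPrediction).doubleValue();
        }
        if (rawPrediction instanceof double[]) {
            double[] values = (double[]) rawPrediction;
            if (values.length != 2) {
                throw new IllegalArgumentException(
                        "Expected a length-2 vector but got length " + values.length);
            }
            // Assumption: index 1 holds the score for label 1.
            return values[1];
        }
        String type = rawPrediction == null ? "null" : rawPrediction.getClass().getName();
        throw new IllegalArgumentException("Unsupported rawPrediction type: " + type);
    }

    public static void main(String[] args) {
        System.out.println(scoreForPositiveLabel(0.75));
        System.out.println(scoreForPositiveLabel(new double[] {0.1, 0.9}));
    }
}
```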
   
   



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =

Review Comment:
   Since we check double value equality using `delta=1.0e-5`, would it be 
simpler to reduce the precision of the expected values here accordingly?
   
   Same for other expected values initialized in this file.
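
   As a quick illustration, the sketch below compares the full-precision `EXPECTED_DATA` values from this test file with copies rounded to five decimal places. `ExpectedPrecisionSketch`, `ROUNDED` and `withinDelta` are hypothetical names, and the rounded values are only a suggestion.

```java
/** Hypothetical check that rounded expected values stay within the test tolerance. */
class ExpectedPrecisionSketch {
    static final double DELTA = 1.0e-5;

    // Full-precision values as they currently appear in EXPECTED_DATA.
    static final double[] FULL = {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};

    // The same values rounded to five decimal places.
    static final double[] ROUNDED = {0.76915, 0.37143, 0.65714};

    /** Returns true if every pair of values differs by at most delta. */
    static boolean withinDelta(double[] expected, double[] actual, double delta) {
        if (expected.length != actual.length) {
            return false;
        }
        for (int i = 0; i < expected.length; i++) {
            if (Math.abs(expected[i] - actual[i]) > delta) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(withinDelta(ROUNDED, FULL, DELTA));
    }
}
```

   Note that rounding to five decimals can use up to 5.0e-6 of the tolerance, so keeping one more digit may be safer if the computed results vary slightly across runs.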



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,731 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The output may contain different metrics
+ * which will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String AREA_UNDER_ROC = "areaUnderROC";

Review Comment:
   Would it be better to move these variables to `BinaryClassificationEvaluatorParams` and make them `public`, so that users can set the parameter by referencing these variables instead of manually typing a string?
   
   Manually typing a string is more error-prone, and the IDE cannot offer completion hints in that case.
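
   A simplified sketch of the idea. `EvaluatorParamsSketch` and `EvaluatorSketch` are hypothetical stand-ins for the real interface and class; only the `AREA_UNDER_ROC` constant name comes from this PR, and the other constant names are assumptions.

```java
/** Hypothetical, simplified stand-in for BinaryClassificationEvaluatorParams. */
interface EvaluatorParamsSketch {
    // Interface fields are implicitly public, static and final.
    String AREA_UNDER_ROC = "areaUnderROC";
    String AREA_UNDER_PR = "areaUnderPR";
    String KS = "KS";
    String AREA_UNDER_LORENZ = "areaUnderLorenz";
}

/** Hypothetical evaluator that exposes only the metric-name setter and getter. */
class EvaluatorSketch implements EvaluatorParamsSketch {
    private String[] metricsNames = {AREA_UNDER_ROC, AREA_UNDER_PR};

    EvaluatorSketch setMetricsNames(String... names) {
        this.metricsNames = names.clone();
        return this;
    }

    String[] getMetricsNames() {
        return metricsNames.clone();
    }

    public static void main(String[] args) {
        // Callers reference the constants, so a typo becomes a compile error.
        EvaluatorSketch eval =
                new EvaluatorSketch()
                        .setMetricsNames(EvaluatorParamsSketch.AREA_UNDER_PR, EvaluatorParamsSketch.KS);
        System.out.println(String.join(",", eval.getMetricsNames()));
    }
}
```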



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("label", "raw");
+        trainDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_MULTI_SCORE))
+                        .as("label", "raw");
+        trainDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_WEIGHT))
+                        .as("label", "raw", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        BinaryClassificationEvaluator loadedEval =
+                StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath());
+        Table evalResult = loadedEval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "KS", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);
+        }
+    }
+
+    @Test
+    public void testEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        Table evalResult = eval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "KS", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);
+        }
+    }
+
+    @Test
+    public void testEvaluateWithMultiScore() throws Exception {

Review Comment:
   Could you help explain why we call this test `testEvaluateWithMultiScore` 
and call the test above `testEvaluate`? It seems that both tests use multiple 
evaluation metrics and their inputs have the same format.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("label", "raw");
+        trainDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_MULTI_SCORE))
+                        .as("label", "raw");
+        trainDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_WEIGHT))
+                        .as("label", "raw", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")

Review Comment:
   Since the goal of this test is to exercise save/load/transform, would it be simpler to rely on the default label column value (i.e. `label`) instead of setting it explicitly? Similarly, can we keep using the default value for `rawPredictionCol`?
   
   The same applies to other tests whose purpose is not to test parameter set/get.
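
   A side note for anyone checking the expected values in this test by hand: the sketch below recomputes the ROC AUC for `TRAIN_DATA` with a brute-force pairwise comparison. It assumes that the second component of each raw-prediction vector is the positive-class score; the class and method names (`PairwiseAucCheck`, `pairwiseAuc`) are made up for illustration, and this is not necessarily how `BinaryClassificationEvaluator` computes the metric internally. Under that assumption, 23 of the 35 positive/negative pairs are ranked correctly, i.e. 23/35 ≈ 0.657, which agrees with the third entry of `EXPECTED_DATA`.

```java
public class PairwiseAucCheck {
    // Labels and assumed positive-class scores (second vector component),
    // copied row by row from TRAIN_DATA in the diff above.
    static final double[] LABELS = {1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1};
    static final double[] SCORES = {
        0.9, 0.8, 0.7, 0.75, 0.6, 0.65, 0.55, 0.4, 0.3, 0.35, 0.2, 0.1
    };

    /**
     * Brute-force ROC AUC: the fraction of (positive, negative) pairs in which
     * the positive sample has the higher score. Ties count as half a win.
     */
    static double pairwiseAuc(double[] labels, double[] scores) {
        if (labels.length != scores.length) {
            throw new IllegalArgumentException("labels and scores differ in length");
        }
        double wins = 0.0;
        long pairs = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != 1.0) {
                continue; // outer loop walks positive samples only
            }
            for (int j = 0; j < labels.length; j++) {
                if (labels[j] != 0.0) {
                    continue; // inner loop walks negative samples only
                }
                pairs++;
                if (scores[i] > scores[j]) {
                    wins += 1.0;
                } else if (scores[i] == scores[j]) {
                    wins += 0.5;
                }
            }
        }
        if (pairs == 0) {
            throw new IllegalArgumentException("need at least one positive and one negative");
        }
        return wins / pairs;
    }

    public static void main(String[] args) {
        System.out.println(pairwiseAuc(LABELS, SCORES));
    }
}
```

   This is quadratic in the number of samples, so it is only useful as a hand check on small fixtures like the one in this test.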


